{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "095c2c45",
   "metadata": {},
   "source": [
    "## Lesson 7 - LoRAX\n",
    "\n",
    "In this lesson, we explore how to efficiently serve LLMs using LoRAX!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1b1de77",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import json\n",
    "import time\n",
    "from typing import List\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from pydantic import BaseModel, constr\n",
    "\n",
    "from lorax import AsyncClient, Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4686ef7f",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import endpoint_url, headers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeceb612",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# Non-async client built from the endpoint URL and headers imported from utils.\n",
    "client = Client(\n",
    "    endpoint_url,\n",
    "    headers=headers,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a838efbc",
   "metadata": {},
   "source": [
    "## Prefill vs Decode (KV Cache)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e6100dc",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "# Time one non-streaming request end to end.\n",
    "# perf_counter() is monotonic, so the interval cannot be skewed by system clock adjustments.\n",
    "t0 = time.perf_counter()\n",
    "resp = client.generate(\"What is deep learning?\", max_new_tokens=32)\n",
    "duration_s = time.perf_counter() - t0\n",
    "\n",
    "print(resp.generated_text)\n",
    "print(\"\\n\\n----------\")\n",
    "print(\"Request duration (s):\", duration_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa960b0d",
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "durations_s = []\n",
    "\n",
    "# perf_counter() is monotonic, which makes it the right clock for short intervals.\n",
    "t0 = time.perf_counter()\n",
    "for resp in client.generate_stream(\"What is deep learning?\", max_new_tokens=32):\n",
    "    # First entry is the time to first token; later entries are gaps between consecutive tokens.\n",
    "    durations_s.append(time.perf_counter() - t0)\n",
    "    if not resp.token.special:\n",
    "        print(resp.token.text, sep=\"\", end=\"\", flush=True)\n",
    "    # Restart the timer after printing so console output is not counted against the next token.\n",
    "    t0 = time.perf_counter()\n",
    "\n",
    "# Guard the metrics: an empty stream has no TTFT, and a single-token stream has no decode time.\n",
    "ttft_s = durations_s[0] if durations_s else float(\"nan\")\n",
    "decode_s = sum(durations_s[1:])\n",
    "throughput = (len(durations_s) - 1) / decode_s if decode_s > 0 else float(\"nan\")\n",
    "\n",
    "print(\"\\n\\n\\n----------\")\n",
    "print(\"Time to first token (TTFT) (s):\", ttft_s)\n",
    "print(\"Throughput (tok / s):\", throughput)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc2ec81c",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Label the figure so it stands alone when the notebook is skimmed.\n",
    "plt.plot(durations_s)\n",
    "plt.xlabel(\"Token index\")\n",
    "plt.ylabel(\"Latency (s)\")\n",
    "plt.title(\"Latency per streamed token (index 0 is time to first token)\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af60d44b",
   "metadata": {},
   "source": [
    "## Continuous Batching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9477d603",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# ANSI foreground color codes, one per concurrent request stream.\n",
    "color_codes = [\n",
    "    \"31\",  # red\n",
    "    \"32\",  # green\n",
    "    \"34\",  # blue\n",
    "]\n",
    "\n",
    "\n",
    "def format_text(text, i):\n",
    "    \"\"\"Wrap `text` in the ANSI escape sequence for the i-th color.\n",
    "\n",
    "    The index wraps around, so more streams than colors reuse colors\n",
    "    instead of raising IndexError.\n",
    "    \"\"\"\n",
    "    color = color_codes[i % len(color_codes)]\n",
    "    return f\"\\x1b[{color}m{text}\\x1b[0m\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9cdf83c",
   "metadata": {
    "height": 403
   },
   "outputs": [],
   "source": [
    "async_client = AsyncClient(endpoint_url, headers=headers)\n",
    "\n",
    "\n",
    "all_max_new_tokens = [100, 10, 10]\n",
    "# One latency list per concurrent request, sized from the request list instead of hard-coded.\n",
    "durations_s = [[] for _ in all_max_new_tokens]\n",
    "\n",
    "\n",
    "async def run(max_new_tokens, i):\n",
    "    \"\"\"Stream one completion, recording per-token latencies in durations_s[i].\"\"\"\n",
    "    t0 = time.perf_counter()\n",
    "    async for resp in async_client.generate_stream(\"What is deep learning?\", max_new_tokens=max_new_tokens):\n",
    "        durations_s[i].append(time.perf_counter() - t0)\n",
    "        # Skip special tokens, as the other streaming cells do, so they are not printed.\n",
    "        if not resp.token.special:\n",
    "            print(format_text(resp.token.text, i), sep=\"\", end=\"\", flush=True)\n",
    "        t0 = time.perf_counter()\n",
    "\n",
    "\n",
    "# perf_counter() is monotonic, so the measured intervals are immune to system clock adjustments.\n",
    "t0 = time.perf_counter()\n",
    "await asyncio.gather(*[run(max_new_tokens, i) for i, max_new_tokens in enumerate(all_max_new_tokens)])\n",
    "\n",
    "print(\"\\n\\n\\n----------\")\n",
    "# Guard the metrics so a stream that yields zero or one token reports NaN instead of raising.\n",
    "print(\"Time to first token (TTFT) (s):\", [s[0] if s else float(\"nan\") for s in durations_s])\n",
    "print(\"Throughput (tok / s):\", [(len(s) - 1) / sum(s[1:]) if sum(s[1:]) > 0 else float(\"nan\") for s in durations_s])\n",
    "print(\"Total duration (s):\", time.perf_counter() - t0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e2f37bf",
   "metadata": {},
   "source": [
    "## Multi-LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f09a9304",
   "metadata": {
    "height": 335
   },
   "outputs": [],
   "source": [
    "def run_with_adapter(prompt, adapter_id):\n",
    "    \"\"\"Stream a completion for `prompt` through the given adapter and print latency metrics.\n",
    "\n",
    "    Args:\n",
    "        prompt: Fully formatted prompt text passed to generate_stream.\n",
    "        adapter_id: Adapter identifier, requested with adapter_source=\"hub\".\n",
    "    \"\"\"\n",
    "    durations_s = []\n",
    "\n",
    "    # perf_counter() is monotonic, so the intervals are immune to system clock adjustments.\n",
    "    t0 = time.perf_counter()\n",
    "    for resp in client.generate_stream(\n",
    "        prompt,\n",
    "        adapter_id=adapter_id,\n",
    "        adapter_source=\"hub\",\n",
    "        max_new_tokens=64,\n",
    "    ):\n",
    "        durations_s.append(time.perf_counter() - t0)\n",
    "        if not resp.token.special:\n",
    "            print(resp.token.text, sep=\"\", end=\"\", flush=True)\n",
    "        t0 = time.perf_counter()\n",
    "\n",
    "    # An empty stream has no TTFT and a single-token stream has no decode time; report NaN then.\n",
    "    ttft_s = durations_s[0] if durations_s else float(\"nan\")\n",
    "    decode_s = sum(durations_s[1:])\n",
    "    throughput = (len(durations_s) - 1) / decode_s if decode_s > 0 else float(\"nan\")\n",
    "\n",
    "    print(\"\\n\\n\\n----------\")\n",
    "    print(\"Time to first token (TTFT) (s):\", ttft_s)\n",
    "    print(\"Throughput (tok / s):\", throughput)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e191cbf",
   "metadata": {
    "height": 318
   },
   "outputs": [],
   "source": [
    "# Prompt template for the hellaswag_processed adapter; {ctx} is the passage to complete.\n",
    "pt_hellaswag_processed = \"\"\"You are provided with an incomplete passage below. Please read the passage and then finish it with an appropriate response. For example:\n",
    "\n",
    "### Passage: My friend and I think alike. We\n",
    "\n",
    "### Ending: often finish each other's sentences.\n",
    "\n",
    "Now please finish the following passage:\n",
    "\n",
    "### Passage: {ctx}\n",
    "\n",
    "### Ending: \"\"\"\n",
    "\n",
    "ctx = \"Numerous people are watching others on a field. Trainers are playing frisbee with their dogs. the dogs\"\n",
    "\n",
    "# Fill the template first, then stream the completion through the adapter.\n",
    "hellaswag_prompt = pt_hellaswag_processed.format(ctx=ctx)\n",
    "run_with_adapter(hellaswag_prompt, adapter_id=\"predibase/hellaswag_processed\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb0ad965",
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "# Prompt template for the cnn adapter; {article} is the text to summarize.\n",
    "pt_cnn = \"\"\"You are given a news article below. Please summarize the article, including only its highlights.\n",
    "\n",
    "### Article: {article}\n",
    "\n",
    "### Summary: \"\"\"\n",
    "\n",
    "article = \"(CNN)Former Vice President Walter Mondale was released from the Mayo Clinic on Saturday after being admitted with influenza, hospital spokeswoman Kelley Luckstein said. \\\"He's doing well. We treated him for flu and cold symptoms and he was released today,\\\" she said. Mondale, 87, was diagnosed after he went to the hospital for a routine checkup following a fever, former President Jimmy Carter said Friday. \\\"He is in the bed right this moment, but looking forward to come back home,\\\" Carter said during a speech at a Nobel Peace Prize Forum in Minneapolis. \\\"He said tell everybody he is doing well.\\\" Mondale underwent treatment at the Mayo Clinic in Rochester, Minnesota. The 42nd vice president served under Carter between 1977 and 1981, and later ran for President, but lost to Ronald Reagan. But not before he made history by naming a woman, U.S. Rep. Geraldine A. Ferraro of New York, as his running mate. Before that, the former lawyer was  a U.S. senator from Minnesota. His wife, Joan Mondale, died last year.\"\n",
    "\n",
    "# Fill the template first, then stream the completion through the adapter.\n",
    "cnn_prompt = pt_cnn.format(article=article)\n",
    "run_with_adapter(cnn_prompt, adapter_id=\"predibase/cnn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c908562c",
   "metadata": {
    "height": 318
   },
   "outputs": [],
   "source": [
    "# Prompt template for the conllpp adapter; {inpt} is the sentence to tag.\n",
    "# Doubled braces are literal braces after str.format().\n",
    "pt_conllpp = \"\"\"\n",
    "Your task is a Named Entity Recognition (NER) task. Predict the category of\n",
    "each entity, then place the entity into the list associated with the \n",
    "category in an output JSON payload. Below is an example:\n",
    "\n",
    "Input: EU rejects German call to boycott British lamb . Output: {{\"person\":\n",
    "[], \"organization\": [\"EU\"], \"location\": [], \"miscellaneous\": [\"German\",\n",
    "\"British\"]}}\n",
    "\n",
    "Now, complete the task.\n",
    "\n",
    "Input: {inpt} Output:\"\"\"\n",
    "\n",
    "inpt = \"Only France and Britain backed Fischler 's proposal .\"\n",
    "\n",
    "# Fill the template first, then stream the completion through the adapter.\n",
    "conllpp_prompt = pt_conllpp.format(inpt=inpt)\n",
    "run_with_adapter(conllpp_prompt, adapter_id=\"predibase/conllpp\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4cc6add",
   "metadata": {
    "height": 556
   },
   "outputs": [],
   "source": [
    "prompts = [\n",
    "    pt_hellaswag_processed.format(ctx=ctx),\n",
    "    pt_cnn.format(article=article),\n",
    "    pt_conllpp.format(inpt=inpt),\n",
    "]\n",
    "adapter_ids = [\"predibase/hellaswag_processed\", \"predibase/cnn\", \"predibase/conllpp\"]\n",
    "# One latency list per request, sized from the adapter list instead of hard-coded.\n",
    "durations_s = [[] for _ in adapter_ids]\n",
    "\n",
    "\n",
    "async def run(prompt, adapter_id, i):\n",
    "    \"\"\"Stream one adapter completion, recording per-token latencies in durations_s[i].\"\"\"\n",
    "    t0 = time.perf_counter()\n",
    "    async for resp in async_client.generate_stream(\n",
    "        prompt,\n",
    "        adapter_id=adapter_id,\n",
    "        adapter_source=\"hub\",\n",
    "        max_new_tokens=64,\n",
    "    ):\n",
    "        durations_s[i].append(time.perf_counter() - t0)\n",
    "        if not resp.token.special:\n",
    "            print(format_text(resp.token.text, i), sep=\"\", end=\"\", flush=True)\n",
    "        t0 = time.perf_counter()\n",
    "\n",
    "\n",
    "# perf_counter() is monotonic, so the measured intervals are immune to system clock adjustments.\n",
    "t0 = time.perf_counter()\n",
    "await asyncio.gather(*[\n",
    "    run(prompt, adapter_id, i)\n",
    "    for i, (prompt, adapter_id) in enumerate(zip(prompts, adapter_ids))\n",
    "])\n",
    "\n",
    "print(\"\\n\\n\\n----------\")\n",
    "# Guard the metrics so a stream that yields zero or one token reports NaN instead of raising.\n",
    "print(\"Time to first token (TTFT) (s):\", [s[0] if s else float(\"nan\") for s in durations_s])\n",
    "print(\"Throughput (tok / s):\", [(len(s) - 1) / sum(s[1:]) if sum(s[1:]) > 0 else float(\"nan\") for s in durations_s])\n",
    "print(\"Total duration (s):\", time.perf_counter() - t0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed936723",
   "metadata": {},
   "source": [
    "## Bonus: Structured Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00431323",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "# BaseModel and constr are already imported in the import cell at the top of the notebook.\n",
    "class Person(BaseModel):\n",
    "    # Name limited to 10 characters via constr.\n",
    "    name: constr(max_length=10)\n",
    "    age: int\n",
    "\n",
    "\n",
    "# NOTE(review): Pydantic v2 deprecates .schema() in favor of .model_json_schema();\n",
    "# confirm the installed Pydantic version before switching.\n",
    "schema = Person.schema()\n",
    "schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08b706aa",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Ask for a JSON object that follows the schema defined in the previous cell.\n",
    "person_response_format = {\"type\": \"json_object\", \"schema\": schema}\n",
    "resp = client.generate(\n",
    "    \"Create a person description for me\",\n",
    "    response_format=person_response_format,\n",
    ")\n",
    "json.loads(resp.generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15870540",
   "metadata": {
    "height": 352
   },
   "outputs": [],
   "source": [
    "# NER prompt template; {input} is the sentence to tag and doubled braces are literal braces.\n",
    "prompt_template = \"\"\"\n",
    "Your task is a Named Entity Recognition (NER) task. Predict the category of\n",
    "each entity, then place the entity into the list associated with the \n",
    "category in an output JSON payload. Below is an example:\n",
    "\n",
    "Input: EU rejects German call to boycott British lamb . Output: {{\"person\":\n",
    "[], \"organization\": [\"EU\"], \"location\": [], \"miscellaneous\": [\"German\",\n",
    "\"British\"]}}\n",
    "\n",
    "Now, complete the task.\n",
    "\n",
    "Input: {input} Output:\"\"\"\n",
    "\n",
    "# Base Mistral-7B\n",
    "ner_prompt = prompt_template.format(input=\"Only France and Britain backed Fischler 's proposal .\")\n",
    "resp = client.generate(ner_prompt, max_new_tokens=128)\n",
    "resp.generated_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65923d86",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# List and BaseModel are already imported in the import cell at the top of the notebook.\n",
    "class Output(BaseModel):\n",
    "    # One list of strings per entity category named in the NER prompt.\n",
    "    person: List[str]\n",
    "    organization: List[str]\n",
    "    location: List[str]\n",
    "    miscellaneous: List[str]\n",
    "\n",
    "\n",
    "# NOTE(review): Pydantic v2 deprecates .schema() in favor of .model_json_schema();\n",
    "# confirm the installed Pydantic version before switching.\n",
    "schema = Output.schema()\n",
    "schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29a1eb2",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# Same NER prompt, with the response format set to the Output schema.\n",
    "ner_prompt = prompt_template.format(input=\"Only France and Britain backed Fischler 's proposal .\")\n",
    "ner_response_format = {\n",
    "    \"type\": \"json_object\",\n",
    "    \"schema\": schema,\n",
    "}\n",
    "resp = client.generate(\n",
    "    ner_prompt,\n",
    "    response_format=ner_response_format,\n",
    "    max_new_tokens=128,\n",
    ")\n",
    "json.loads(resp.generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c56fca8",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "# Same NER prompt, routed through the conllpp adapter with no response_format argument.\n",
    "ner_prompt = prompt_template.format(input=\"Only France and Britain backed Fischler 's proposal .\")\n",
    "resp = client.generate(\n",
    "    ner_prompt,\n",
    "    adapter_id=\"predibase/conllpp\",\n",
    "    adapter_source=\"hub\",\n",
    "    max_new_tokens=128,\n",
    ")\n",
    "json.loads(resp.generated_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee30357c",
   "metadata": {},
   "source": [
    "## Additional Predibase Resources\n",
    "\n",
    "- LoRA Land demo environment: https://predibase.com/lora-land\n",
    "- Predibase blog articles: https://predibase.com/blog\n",
    "- LoRAX GitHub repo: https://github.com/predibase/lorax\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
